import os
import random
from tqdm import tqdm


seed = 64  # seed for the drop/split RNG so the produced splits are reproducible
folder = "./pretrain_dataset"
train_filename = os.path.join(folder, "train.txt")
eval_filename = os.path.join(folder, "validation.txt")
p = 0.2  # probability that a kept line goes to the validation split
drop_p = 0.88  # the full ~60M-line corpus is too large to train on; keep only a subset
src_folder = "./extend_poi_dataset"


def _ensure_parent_dir(path):
    """Create the parent directory of *path* if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def split_dataset(
    src_dir=src_folder,
    train_path=train_filename,
    eval_path=eval_filename,
    eval_p=p,
    drop_prob=drop_p,
    rng_seed=seed,
    progress=None,
):
    """Subsample every file in *src_dir* and split the kept lines into train/eval files.

    Each input line is first discarded with probability *drop_prob*. A surviving
    line is written to *eval_path* with probability *eval_p*, otherwise to
    *train_path*. Both output files are overwritten.

    Args:
        src_dir: Directory whose regular files are read as UTF-8 text. Files are
            processed in sorted-name order so results are deterministic per seed;
            sub-directories are skipped.
        train_path: Output path for the training split.
        eval_path: Output path for the validation split.
        eval_p: Probability that a kept line goes to the validation split.
        drop_prob: Probability that a line is dropped before splitting.
        rng_seed: Seed for a private ``random.Random`` instance (global
            ``random`` state is not touched).
        progress: Optional wrapper (e.g. ``tqdm``) applied to each file's line
            iterator for progress reporting; ``None`` disables it.

    Returns:
        A ``(n_train, n_eval)`` tuple with the number of lines written to each split.
    """
    rng = random.Random(rng_seed)
    _ensure_parent_dir(train_path)
    _ensure_parent_dir(eval_path)

    n_train = n_eval = 0
    with open(train_path, "w", encoding="utf-8") as train_file:
        with open(eval_path, "w", encoding="utf-8") as eval_file:
            for name in sorted(os.listdir(src_dir)):
                src = os.path.join(src_dir, name)
                if not os.path.isfile(src):
                    continue  # skip sub-directories etc. instead of crashing in open()
                with open(src, "r", encoding="utf-8") as f:
                    # Stream lines lazily; readlines() would hold the whole file in memory.
                    lines = f if progress is None else progress(f, desc=name, unit="line")
                    for line in lines:
                        if rng.random() < drop_prob:
                            continue
                        # A file's final line may lack "\n"; without this it would be
                        # glued onto the next file's first line in the output.
                        if not line.endswith("\n"):
                            line += "\n"
                        if rng.random() <= eval_p:
                            eval_file.write(line)
                            n_eval += 1
                        else:
                            train_file.write(line)
                            n_train += 1
    return n_train, n_eval


if __name__ == "__main__":
    split_dataset(progress=tqdm)